import asyncio
from typing import Type
from cognee.shared.logging_utils import get_logger

from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import (
    create_dynamic_baml_type,
)
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type_builder import (
    TypeBuilder,
)
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
from pydantic import BaseModel


# Module-level logger obtained from cognee's shared logging utilities.
logger = get_logger()


async def acreate_structured_output(
    text_input: str, system_prompt: str, response_model: Type[BaseModel]
):
    """
    Generate a structured response for a user query via BAML.

    This coroutine builds a dynamic BAML response type derived from
    ``response_model``. It then calls the BAML ``AcreateStructuredOutput`` function
    with the configured client registry and converts the BAML result back into
    the requested type.

    Parameters:
    -----------

        - text_input (str): The input text provided by the user for generating a response.
        - system_prompt (str): The system's prompt to guide the model's response.
        - response_model (Type[BaseModel]): The expected model type for the response.
          ``str`` is also accepted, in which case plain text is returned.

    Returns:
    --------

        - BaseModel | str: An instance of ``response_model`` validated from the BAML
          response. When ``response_model`` is ``str``, the function returns the
          response text as a ``str`` instead.
    """
    config = get_llm_config()

    # Dynamically build the BAML response type from `response_model`. The
    # resulting type builder is handed to BAML through `baml_options["tb"]`.
    tb = TypeBuilder()
    type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)

    result = await b.AcreateStructuredOutput(
        text_input=text_input,
        system_prompt=system_prompt,
        baml_options={"client_registry": config.baml_registry, "tb": type_builder},
    )

    # Transform the BAML response into the requested pydantic response model.
    if response_model is str:
        # When the response model is `str`, the result is stored in the `text`
        # property of the BAML response model.
        return str(result.text)
    # `.dict()` is deprecated in pydantic v2 (already used here via
    # `model_validate`); `model_dump()` is the supported equivalent.
    return response_model.model_validate(result.model_dump())


if __name__ == "__main__":
    from typing import Dict, List, Optional

    # Minimal model used to smoke-test the BAML structured-output call.
    class TestModel(BaseModel):
        type: str
        source: Optional[str] = None
        target: Optional[str] = None
        properties: Optional[Dict[str, List[str]]] = None

    # asyncio.run creates a fresh event loop, shuts down async generators and
    # closes the loop afterwards; the previous manual loop was never closed.
    output = asyncio.run(acreate_structured_output("TEST", "THIS IS A TEST", TestModel))
    print(output)
